import numpy as np
import pandas as p
from tqdm import tqdm
from DetermineComparisonStates import load_household_data, getVehicleDataByYear
from copy import deepcopy



def add_prefix_col_names(data, prefix):
    """Prepend ``prefix`` to every column name except ``'HOUSEID'``.

    The rename is applied to ``data`` in place and the same object is
    returned, so the result can be used either way.

    Args:
        data: DataFrame whose columns are renamed.
        prefix: String placed in front of each column name.

    Returns:
        The same ``data`` object, with renamed columns.
    """
    # 'HOUSEID' is the join key used by the callers and must keep its name.
    renamed = {
        name: name if name == 'HOUSEID' else prefix + name
        for name in data.columns
    }
    data.rename(columns=renamed, inplace=True)
    return data

def load_vehicle_data(year=2017, small=False):
    """Build one row of vehicle summary statistics per household.

    Vehicle-level records returned by ``getVehicleDataByYear(year, small)``
    are grouped by ``HOUSEID`` and aggregated with min, max, sum and mean.
    Negative ``ANNMILES``, ``VEHAGE`` and ``VEHYEAR`` values are treated as
    missing: they are replaced by NaN before aggregating, and 0/1 flag
    columns record how many usable values each household contributed.

    Note that the frame returned by ``getVehicleDataByYear`` is modified in
    place (flag columns are added and negative values are blanked).

    Args:
        year: Passed through to ``getVehicleDataByYear``.
        small: Passed through to ``getVehicleDataByYear``.

    Returns:
        DataFrame with one row per ``HOUSEID`` and the columns
        ``num_vehicles``, ``min_veh_age``, ``max_veh_age``, ``mean_veh_age``,
        ``State Code``, ``alt_state_code``, ``sum_total_miles``,
        ``miles_per_veh``, ``min_miles``, ``max_miles``, ``min_model_year``,
        ``max_model_year``, ``share_new_veh``, ``num_new_veh``,
        ``num_newly_bought_veh``, ``sum_miles_count``, ``sum_age_count``,
        ``sum_year_count``, ``single_vehicle_mean_age``,
        ``multi_vehicle_mean_age``, ``multi_vehicle_num_vehicles`` and
        ``all_vehicle_num_vehicles``.
    """
    veh_data = getVehicleDataByYear(year, small)

    # (value column, flag column): the flag is 1 when the value is usable
    # (non-negative) and 0 otherwise, so that summing the flag per household
    # counts the usable observations.
    tracked_columns = (
        ('ANNMILES', '_miles_count'),
        ('VEHAGE', '_age_count'),
        ('VEHYEAR', '_year_count'),
    )
    for value_col, flag_col in tracked_columns:
        invalid = veh_data.loc[:, value_col] < 0
        veh_data.loc[:, flag_col] = 1
        veh_data.loc[invalid, flag_col] = 0
        # mask() upcasts integer columns cleanly instead of assigning NaN
        # into them element by element.
        veh_data[value_col] = veh_data[value_col].mask(invalid)

    # Indicator columns based on VEHOWNED == 2 (and, for 'new_vehicle', a
    # vehicle age of exactly 1).
    owned_code_two = veh_data.loc[:, 'VEHOWNED'] == 2
    veh_data.loc[:, 'new_vehicle'] = 0
    veh_data.loc[:, 'newly_bought_vehicle'] = 0
    veh_data.loc[owned_code_two & (veh_data.loc[:, 'VEHAGE'] == 1), 'new_vehicle'] = 1
    veh_data.loc[owned_code_two, 'newly_bought_vehicle'] = 1

    by_household = veh_data.groupby('HOUSEID', as_index=False)
    # sum/mean are restricted to numeric columns explicitly: every sum*/mean*
    # column used below is numeric, and without this recent pandas versions
    # raise on text columns instead of silently skipping them.
    aggregates = (
        ('min', by_household.min()),
        ('max', by_household.max()),
        ('sum', by_household.sum(numeric_only=True)),
        ('mean', by_household.mean(numeric_only=True)),
    )

    full_data = None
    for prefix, frame in aggregates:
        frame = add_prefix_col_names(frame, prefix)
        if full_data is None:
            full_data = frame
        else:
            full_data = full_data.merge(frame, on='HOUSEID', how='outer')

    # Make sure points without data aren't accidentally recorded as zero.
    # Total annual miles is the exception: it is set to zero because baseline
    # travel is assumed to be zero miles.
    no_miles = full_data.loc[:, 'sum_miles_count'] == 0
    full_data.loc[no_miles, 'sumANNMILES'] = 0
    for stat in ('min', 'max', 'mean'):
        full_data.loc[no_miles, stat + 'ANNMILES'] = np.nan

    for value_col, flag_col in (('VEHAGE', '_age_count'), ('VEHYEAR', '_year_count')):
        no_usable_value = full_data.loc[:, 'sum' + flag_col] == 0
        for stat in ('sum', 'min', 'max', 'mean'):
            full_data.loc[no_usable_value, stat + value_col] = np.nan

    return_cols = ['HOUSEID', 'maxVEHID', 'minVEHAGE', 'maxVEHAGE', 'meanVEHAGE', 'minHHSTATE', 'maxHHSTATE', 'sumANNMILES', 'meanANNMILES', 'minANNMILES', 'maxANNMILES', 'minVEHYEAR', 'maxVEHYEAR', 'meannew_vehicle', 'sumnew_vehicle', 'sumnewly_bought_vehicle', 'sum_miles_count', 'sum_age_count', 'sum_year_count']

    # NOTE(review): the largest VEHID is used as the vehicle count, which
    # presumably relies on VEHIDs being numbered 1..n within a household.
    clean_names = {
        'maxVEHID': 'num_vehicles',
        'minVEHAGE': 'min_veh_age',
        'maxVEHAGE': 'max_veh_age',
        'meanVEHAGE': 'mean_veh_age',
        'minHHSTATE': 'State Code',
        'maxHHSTATE': 'alt_state_code',
        'sumANNMILES': 'sum_total_miles',
        'meanANNMILES': 'miles_per_veh',
        'minANNMILES': 'min_miles',
        'maxANNMILES': 'max_miles',
        'minVEHYEAR': 'min_model_year',
        'maxVEHYEAR': 'max_model_year',
        'meannew_vehicle': 'share_new_veh',
        'sumnew_vehicle': 'num_new_veh',
        'sumnewly_bought_vehicle': 'num_newly_bought_veh',
    }
    # rename() without inplace returns a new frame, so the assignments below
    # never write into a slice of full_data.
    return_data = full_data.loc[:, return_cols].rename(columns=clean_names)

    # No further scrubbing of negative minima is needed here: negatives were
    # replaced by NaN before aggregating, so no per-household minimum can be
    # negative.

    single_vehicle = return_data.loc[:, 'num_vehicles'] == 1
    multi_vehicle = return_data.loc[:, 'num_vehicles'] > 1
    any_vehicle = return_data.loc[:, 'num_vehicles'] >= 0

    return_data.loc[:, 'single_vehicle_mean_age'] = np.nan
    return_data.loc[single_vehicle, 'single_vehicle_mean_age'] = return_data.loc[single_vehicle, 'mean_veh_age']

    return_data.loc[:, 'multi_vehicle_mean_age'] = np.nan
    return_data.loc[multi_vehicle, 'multi_vehicle_mean_age'] = return_data.loc[multi_vehicle, 'mean_veh_age']

    return_data.loc[:, 'multi_vehicle_num_vehicles'] = np.nan
    return_data.loc[multi_vehicle, 'multi_vehicle_num_vehicles'] = return_data.loc[multi_vehicle, 'num_vehicles']

    # Unlike the multi-vehicle variant this column defaults to 0, not NaN.
    return_data.loc[:, 'all_vehicle_num_vehicles'] = 0
    return_data.loc[any_vehicle, 'all_vehicle_num_vehicles'] = return_data.loc[any_vehicle, 'num_vehicles']

    return return_data

def group_household_data_by_state(data, columns, weight_col):
    """Compute weighted per-state means of household-level columns.

    For every name ``col`` in ``columns`` the following helper columns are
    added to ``data`` (in place):

    * ``'{col}-weight'``: the household weight, or NaN where ``col`` has no
      usable value;
    * ``'{col}-weighted-sum'``: ``col`` multiplied by that weight.

    Negative values of ``col`` are treated as missing and replaced by NaN in
    ``data``. A ``'count'`` column of ones is also added. After summing by
    ``'State Code'`` the weighted mean ``'{col}-weighted'`` is the ratio of
    the two summed helper columns.

    Args:
        data: Household-level DataFrame containing ``'State Code'``,
            ``weight_col`` and every entry of ``columns``. Modified in place.
        columns: Names of the columns to average.
        weight_col: Name of the weight column.

    Returns:
        Tuple ``(grouped, data)``: ``grouped`` has one row per state with the
        summed inputs, the summed helper columns and the weighted means;
        ``data`` is the (modified) input frame.
    """
    return_cols = ['State Code', *columns, weight_col]

    data['count'] = 1

    for col in columns:
        temp_sum_col = f'{col}-weighted-sum'
        temp_weight_col = f'{col}-weight'
        return_cols.extend([temp_sum_col, temp_weight_col])

        # Negative codes must become NaN *before* the weight is derived.
        # Otherwise such rows drop out of the weighted sum but still add
        # their weight to the denominator, biasing the mean towards zero.
        values = data[col].mask(data[col] < 0)
        weights = data[weight_col].where(values.notna())

        data[col] = values
        data[temp_weight_col] = weights
        data[temp_sum_col] = values * weights

    # Sum only the columns that are returned. Unrelated columns (text,
    # categoricals such as a merge indicator) are left out, since summing
    # them is wasteful and fails on recent pandas versions.
    value_cols = [name for name in dict.fromkeys(return_cols) if name != 'State Code']
    grouped_data = data.groupby('State Code', as_index=False)[value_cols].sum()

    for col in columns:
        temp_sum_col = f'{col}-weighted-sum'
        temp_weighted_col = f'{col}-weighted'
        temp_weight_col = f'{col}-weight'
        return_cols.append(temp_weighted_col)
        try:
            grouped_data.loc[:, temp_weighted_col] = grouped_data.loc[:, temp_sum_col]/grouped_data.loc[:, temp_weight_col]
        except KeyError:
            # The helper columns can be missing from the aggregate when the
            # pandas version in use drops columns it cannot sum.
            print(f'Cannot find {col}')

    print(return_cols)
    return grouped_data.loc[:, return_cols], data


def get_oldest_vehicle_travel(data):
    """Return the oldest vehicle of every household and its annual mileage.

    Within each ``HOUSEID`` the row with the largest ``VEHAGE`` is selected;
    ties are broken in favour of the smallest ``ANNMILES``. Negative
    ``ANNMILES`` values are treated as missing (NaN) and therefore rank last
    among tied rows. ``data`` itself is not modified.

    Args:
        data: Vehicle-level DataFrame with at least the columns ``HOUSEID``,
            ``VEHAGE`` and ``ANNMILES``.

    Returns:
        DataFrame with one row per household and the columns ``HOUSEID``,
        ``oldest_veh_age`` and ``oldest_veh_travel``, ordered by descending
        age and then ascending mileage. An empty input yields an empty frame.
    """
    keep_cols = ['HOUSEID', 'VEHAGE', 'ANNMILES']
    clean_names = ['HOUSEID', 'oldest_veh_age', 'oldest_veh_travel']

    # Work on a private copy of the three needed columns so the caller's
    # frame is left untouched.
    vehicles = data.loc[:, keep_cols].copy()
    vehicles['ANNMILES'] = vehicles['ANNMILES'].mask(vehicles['ANNMILES'] < 0)

    # One global sort replaces a per-household filter-and-sort: after the
    # sort, the first row seen for each HOUSEID is that household's oldest
    # vehicle (lowest mileage among equally old ones).
    ranked = vehicles.sort_values(['VEHAGE', 'ANNMILES'], ascending=[False, True])
    oldest = ranked.drop_duplicates(subset='HOUSEID', keep='first')

    return oldest.rename(columns=dict(zip(keep_cols, clean_names)))


if __name__=='__main__':
    # For each survey year: build per-household vehicle statistics, join them
    # to the household records, compute weighted state-level means and write
    # both the state-level and the household-level tables to CSV.

    small_set = ['HHSIZE', 'DRVRCNT', 'Income', 'HBPPOPDN', 'HasSafety', 'TreatmentYear']
    large_set = deepcopy(small_set)
    large_set.extend(['WRKCOUNT', 'NUMADLT', 'YOUNGCHILD', 'Age'])
    # Each column is listed once: a repeated name would be processed twice
    # and show up as duplicated columns in the output files.
    veh_cols = ['num_vehicles', 'min_veh_age', 'max_veh_age', 'mean_veh_age', 'min_model_year', 'max_model_year', 'share_new_veh', 'num_new_veh', 'sum_total_miles', 'miles_per_veh', 'min_miles', 'max_miles', 'single_vehicle_mean_age', 'multi_vehicle_mean_age', 'multi_vehicle_num_vehicles', 'all_vehicle_num_vehicles', 'num_newly_bought_veh', 'oldest_veh_travel', 'all_hh_new_veh', 'all_hh_newly_bought_veh']
    large_set.extend(veh_cols)

    weight_col = 'WTHHFIN'
    years = [2009, 2017]
    for year in tqdm(years):
        hh_veh_data = load_vehicle_data(year=year)
        # The vehicle table is loaded a second time here because
        # load_vehicle_data modifies the frame it loads.
        veh_data = getVehicleDataByYear(year, False)
        old_veh_data = get_oldest_vehicle_travel(veh_data)
        # HOUSEID is the only column the two frames share; name it explicitly.
        hh_veh_data = hh_veh_data.merge(old_veh_data, on='HOUSEID', how='outer')

        hh_data = load_household_data(small=False, window=0, year=year)

        # indicator=True adds the '_merge' column used below to find
        # households that have no vehicle record.
        joint_data = hh_data.merge(hh_veh_data, on='HOUSEID', how='outer', indicator=True, suffixes=('', '_veh'))

        # "all_hh_*" variants: households without vehicle records count as 0
        # instead of being missing.
        joint_data.loc[:, 'all_hh_new_veh'] = 0
        joint_data.loc[:, 'all_hh_newly_bought_veh'] = 0

        has_new_count = ~joint_data.loc[:, 'num_new_veh'].isna()
        joint_data.loc[has_new_count, 'all_hh_new_veh'] = joint_data.loc[has_new_count, 'num_new_veh']
        # Select rows and values with the same mask, built from the column
        # that is actually being copied.
        has_bought_count = ~joint_data.loc[:, 'num_newly_bought_veh'].isna()
        joint_data.loc[has_bought_count, 'all_hh_newly_bought_veh'] = joint_data.loc[has_bought_count, 'num_newly_bought_veh']

        left_only_indices = joint_data.loc[:, '_merge']=='left_only'

        joint_data.loc[left_only_indices, 'num_vehicles'] = 0
        joint_data.loc[left_only_indices, 'sum_total_miles'] = 0
        joint_data.loc[left_only_indices, 'oldest_veh_travel'] = np.nan

        hh_data_grouped, full_hh_data = group_household_data_by_state(joint_data, large_set, weight_col)

        # Keep the state identifier, the weighted means ('-weighted' but not
        # the '-weighted-sum' helpers) and every total-miles column.
        keep_cols = [
            col for col in hh_data_grouped.columns
            if ('-weighted' in col and 'sum' not in col) or 'Code' in col or 'sum_total_miles' in col
        ]

        # Copy so the assignments below work on an independent frame rather
        # than on a slice of hh_data_grouped.
        hh_data_grouped_small = hh_data_grouped.loc[:, keep_cols].copy()
        hh_data_grouped_small.loc[:, 'TreatmentYear-weighted'] = np.round(hh_data_grouped_small.loc[:, 'TreatmentYear-weighted'], 0)
        # NOTE(review): astype(int) truncates rather than rounds and fails on
        # NaN; presumably HasSafety is constant within a state - confirm.
        hh_data_grouped_small.loc[:, 'HasSafety-weighted'] = hh_data_grouped_small.loc[:,
                                                                 'HasSafety-weighted'].astype(int)
        hh_data_grouped_small.to_csv(f'CleanData/NHTS{year}.csv', index=False)
        full_hh_data.to_csv(f'CleanData/FullNHTS{year}.csv', index=False)